// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "test/common/GoldModel.h"

#include <ac_std_float.h>

#include <fstream>

#include "src/ArchitectureParams.h"
#include "test/common/operation.proto.h"

// Fused multiply-add overload for ac::bfloat16: returns input * weight + psum
// in a single rounding step via the type's own fma, using convergent rounding
// (AC_RND_CONV). NOTE(review): the `false` template argument presumably
// disables saturation — verify against ac_std_float.h.
ac::bfloat16 fma(ac::bfloat16 input, ac::bfloat16 weight, ac::bfloat16 psum) {
  return input.fma<AC_RND_CONV, false>(weight, psum);
}

// Generic fused multiply-add fallback for arithmetic types that have no
// dedicated fma member: computes input * weight + psum.
template <typename DTYPE>
DTYPE fma(DTYPE input, DTYPE weight, DTYPE psum) {
  const DTYPE product = input * weight;
  return product + psum;
}

// Reference (gold) implementation of a tensor operation: either a matrix
// multiplication C = A * B (with optional head reshapes, output scaling and
// ReLU) or a 2D convolution with "same" zero padding.
//
// matrixA and matrixB hold the operation inputs; matrixC receives the result.
// All buffers are row-major and must be sized per the operation's dimensions.
void run_tensor_operation(
    const third_party::dnn_accelerator_hls::test::common::TensorOperation
        &operation,
    INPUT_DATATYPE *matrixA, INPUT_DATATYPE *matrixB,
    OUTPUT_DATATYPE *matrixC) {
  if (operation.has_matrix_multiplication()) {
    const third_party::dnn_accelerator_hls::test::common::
        MatrixMultiplicationOperation matMulOp =
            operation.matrix_multiplication();
    const third_party::dnn_accelerator_hls::test::common::Reshapes reshapes =
        operation.reshapes();

    // A is (ay x ax), B is (ax x bx), C is (ay x bx).
    for (int ay = 0; ay < matMulOp.ay(); ay++) {
      for (int bx = 0; bx < matMulOp.bx(); bx++) {
        OUTPUT_DATATYPE acc = 0;
        for (int ax = 0; ax < matMulOp.ax(); ax++) {
          int matrixAIndex = ay * matMulOp.ax() + ax;
          if (reshapes.concatenate_heads()) {
            // A is stored head-separated as (num_heads, ay, head_size); map
            // the logical (ay, ax) element back into that layout.
            // BUG FIX: the previous index used the loop *bounds*
            // (matMulOp.ax()) instead of the loop variables and dropped the
            // ay*head_size term, so the same A element was read on every
            // iteration. NOTE(review): assumes the standard multi-head
            // layout above — confirm against the accelerator's reshape spec.
            matrixAIndex =
                (ax / reshapes.head_size()) * matMulOp.ay() *
                    reshapes.head_size() +
                ay * reshapes.head_size() + (ax % reshapes.head_size());
          }

          int matrixBIndex = ax * matMulOp.bx() + bx;
          if (reshapes.transpose_matrix_b()) {
            // B supplied as (bx x ax) instead of (ax x bx).
            matrixBIndex = bx * matMulOp.ax() + ax;
          }

          acc = fma(matrixA[matrixAIndex], matrixB[matrixBIndex], acc);
        }

        // Optional post-multiplication scaling (e.g. 1/sqrt(d) in attention).
        if (operation.has_scaling()) {
          OUTPUT_DATATYPE scale_factor = operation.scaling().scale_factor();
          acc *= scale_factor;
        }

        if (operation.activation_function().relu()) {
          if (acc < 0) acc = 0;
        }

        int matrixCIndex = ay * matMulOp.bx() + bx;
        if (reshapes.split_heads()) {
          // C is written head-separated as (num_heads, ay, head_size).
          // BUG FIX: the previous index scaled by the loop variable ay
          // instead of the matrix height and dropped the ay*head_size term,
          // so distinct outputs collided on the same index.
          matrixCIndex =
              (bx / reshapes.head_size()) * matMulOp.ay() *
                  reshapes.head_size() +
              ay * reshapes.head_size() + (bx % reshapes.head_size());
        }
        matrixC[matrixCIndex] = acc;
      }
    }
  } else {
    const third_party::dnn_accelerator_hls::test::common::ConvolutionOperation
        convOp = operation.convolution();

    // "Same" zero-padded convolution. Filter offsets are centered on the
    // output pixel. NOTE(review): (f-1)/2 bounds assume odd filter sizes;
    // even sizes would lose a row/column — confirm supported filter shapes.
    for (int ox = 0; ox < convOp.ox(); ox++) {
      for (int oy = 0; oy < convOp.oy(); oy++) {
        for (int oc = 0; oc < convOp.oc(); oc++) {
          OUTPUT_DATATYPE acc = 0;
          for (int ic = 0; ic < convOp.ic(); ic++) {
            // BUG FIX: the upper bounds were exclusive (<), dropping the
            // last filter row and column (a 3x3 filter only covered 2x2
            // taps); matrixBIndex below maps f + (f-1)/2 over the full
            // filter extent, so the bound must be inclusive.
            for (int fx = -(convOp.fx() - 1) / 2; fx <= (convOp.fx() - 1) / 2;
                 fx++) {
              for (int fy = -(convOp.fy() - 1) / 2;
                   fy <= (convOp.fy() - 1) / 2; fy++) {
                // Skip taps that fall into the zero padding.
                if (convOp.stride() * ox + fx < 0 ||
                    convOp.stride() * oy + fy < 0 ||
                    convOp.stride() * ox + fx >=
                        convOp.stride() * convOp.ox() ||
                    convOp.stride() * oy + fy >=
                        convOp.stride() * convOp.oy()) {
                  // do nothing
                } else {
                  // Input is (iy, ix, ic) row-major; weights are
                  // (fy, fx, ic, oc) row-major.
                  int matrixAIndex =
                      (oy * convOp.stride() + fy) *
                          (convOp.ox() * convOp.stride()) * convOp.ic() +
                      (ox * convOp.stride() + fx) * convOp.ic() + ic;
                  int matrixBIndex = ((fy + (convOp.fy() - 1) / 2) *
                                      convOp.fx() * convOp.ic() * convOp.oc()) +
                                     ((fx + (convOp.fx() - 1) / 2) *
                                      convOp.ic() * convOp.oc()) +
                                     (ic * convOp.oc()) + oc;
                  acc = fma(matrixA[matrixAIndex], matrixB[matrixBIndex], acc);
                }
              }
            }
          }
          if (operation.activation_function().relu()) {
            if (acc < 0) acc = 0;
          }

          // Output is (oy, ox, oc) row-major, matching the input layout.
          // BUG FIX: the oy term was not scaled by oc(), so output rows
          // overlapped each other.
          int matrixCIndex = (oy * convOp.ox() + ox) * convOp.oc() + oc;
          matrixC[matrixCIndex] = acc;
        }
      }
    }
  }
}

// Row-wise numerically-stable softmax over a (tensor_height x tensor_width)
// matrix: C[y][x] = exp(A[y][x] - row_max) / sum_x exp(A[y][x] - row_max).
//
// matrixA is the input; matrixC receives the per-row softmax.
void run_softmax(
    const third_party::dnn_accelerator_hls::test::common::Softmax &softmaxOp,
    INPUT_DATATYPE *matrixA, OUTPUT_DATATYPE *matrixC) {
  for (int y = 0; y < softmaxOp.tensor_height(); y++) {
    // calculate max(x) so exp() cannot overflow. NOTE(review): max starts at
    // 0, so for all-negative rows the shift is 0 — still mathematically
    // correct because softmax is shift-invariant.
    INPUT_DATATYPE max = 0;
    for (int x = 0; x < softmaxOp.tensor_width(); x++) {
      // BUG FIX: previously assigned the loop index x instead of the
      // element's value, so the "max" was just the position of the last
      // element exceeding it.
      if (matrixA[y * softmaxOp.tensor_width() + x] > max)
        max = matrixA[y * softmaxOp.tensor_width() + x];
    }

    // calculate sum(exp(x-max)), stashing the numerators in matrixC.
    INPUT_DATATYPE sum = 0;
    for (int x = 0; x < softmaxOp.tensor_width(); x++) {
      INPUT_DATATYPE exponential =
          exp((matrixA[y * softmaxOp.tensor_width() + x] - max).to_float());
      sum += exponential;
      matrixC[y * softmaxOp.tensor_width() + x] = exponential;
    }

    // divide by sum to normalize the row to 1.
    for (int x = 0; x < softmaxOp.tensor_width(); x++) {
      matrixC[y * softmaxOp.tensor_width() + x] /= sum;
    }
  }
}

// Normalizes the input tensor to zero mean and unit variance:
// C = (A - mean) / std_dev, where mean and variance are computed over ALL
// (tensor_height x tensor_width) elements of the tensor.
//
// matrixA is the input; matrixC receives the normalized values.
void run_layernorm(
    const third_party::dnn_accelerator_hls::test::common::LayerNorm
        &layerNormOp,
    INPUT_DATATYPE *matrixA, OUTPUT_DATATYPE *matrixC) {
  const int num_elements =
      layerNormOp.tensor_height() * layerNormOp.tensor_width();

  // First pass: accumulate the total and derive the mean.
  INPUT_DATATYPE total = 0;
  for (int i = 0; i < num_elements; i++) {
    total += matrixA[i];
  }
  INPUT_DATATYPE mean = total / num_elements;

  // Second pass: accumulate squared deviations from the mean.
  INPUT_DATATYPE sq_sum = 0;
  for (int i = 0; i < num_elements; i++) {
    INPUT_DATATYPE centered = matrixA[i] - mean;
    sq_sum += centered * centered;
  }
  INPUT_DATATYPE variance = sq_sum / num_elements;

  INPUT_DATATYPE std_dev = std::sqrt(variance.to_float());

  // Third pass: write out the normalized tensor.
  for (int i = 0; i < num_elements; i++) {
    matrixC[i] = (matrixA[i] - mean) / std_dev;
  }
}

// Dispatches an Operation proto to the matching gold-model implementation
// (tensor operation, softmax, or layer norm). Aborts via LOG(FATAL) when the
// proto carries none of the recognized operations.
void run_gold_op(
    const third_party::dnn_accelerator_hls::test::common::Operation &operation,
    INPUT_DATATYPE *matrixA, INPUT_DATATYPE *matrixB,
    OUTPUT_DATATYPE *matrixC) {
  std::cout << "Running gold model " << std::endl;

  if (operation.has_tensor_operation()) {
    run_tensor_operation(operation.tensor_operation(), matrixA, matrixB,
                         matrixC);
    return;
  }
  if (operation.has_softmax()) {
    run_softmax(operation.softmax(), matrixA, matrixC);
    return;
  }
  if (operation.has_layer_norm()) {
    run_layernorm(operation.layer_norm(), matrixA, matrixC);
    return;
  }
  LOG(FATAL) << "Invalid or unmapped operation. Please double check the "
                "textproto.";
}
